{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Seq2Seq模型实现文本翻译\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/tutorials/application/zh_cn/nlp/mindspore_sequence_to_sequence.ipynb)&emsp;[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/tutorials/application/zh_cn/nlp/mindspore_sequence_to_sequence.py)&emsp;[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/tutorials/application/source_zh_cn/nlp/sequence_to_sequence.ipynb)\n",
    "\n",
    "## 概述\n",
    "\n",
    "序列到序列模型（sequence to sequence model），又名Seq2Seq模型。它是一种循环神经网络（Recurrent Neural Network，RNN）的变种，突破了原本RNN模型对于输入和输出序列长度的限制，做到将输入序列映射到另一个长度不同的输出序列，因此常用于机器翻译的任务。\n",
    "\n",
    "Seq2Seq模型一般结构为编码器（encoder）+ 解码器（decoder），前者负责把输入序列编码成一个固定长度的向量，后者将这个向量转化为可变长度的向量。\n",
    "\n",
    "![avatar1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/nlp/images/seq2seq_1.png)\n",
    "\n",
    "> 图片来源：\n",
    ">\n",
    "> https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb\n",
    "\n",
    "后来，人们在encoder-decoder的基础上引入了注意力机制（attention），使模型在各个任务中的表现更为出色。\n",
    "\n",
    "## 数据准备\n",
    "\n",
    "我们本次使用的数据集为**Multi30K数据集**，它是一个大规模的图像-文本数据集，包含30K+图片，每张图片对应两类不同的文本描述：\n",
    "\n",
    "- 英语描述，及对应的德语翻译；\n",
    "- 五个独立的、非翻译而来的英语和德语描述，描述中包含的细节并不相同；\n",
    "\n",
    "因其收集的不同语言对于图片的描述相互独立，所以训练出的模型可以更好地适用于有噪声的多模态内容。\n",
    "\n",
    "![avatar2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/nlp/images/seq2seq_2.png)\n",
    "\n",
    "> 图片来源：\n",
    ">\n",
    "> Elliott, D., Frank, S., Sima’an, K., & Specia, L. (2016). Multi30K: Multilingual English-German Image Descriptions. CoRR, 1605.00459.\n",
    "\n",
    "首先，我们需要安装如下依赖：\n",
    "\n",
    "- 分词工具：`pip install spacy`\n",
    "- 德语/英语分词器：`python -m spacy download de_core_news_sm`，`python -m spacy download en_core_web_sm`\n",
    "- BLEU Score计算：`pip install nltk`\n",
    "\n",
    "### 数据下载模块\n",
    "\n",
    "使用`download`进行数据下载，并将`tar.gz`文件解压到指定文件夹。\n",
    "\n",
    "下载好的数据集目录结构如下：\n",
    "\n",
    "```text\n",
    "home_path/.mindspore_examples\n",
    "├─test\n",
    "│      test2016.de\n",
    "│      test2016.en\n",
    "│      test2016.fr\n",
    "│\n",
    "├─train\n",
    "│      train.de\n",
    "│      train.en\n",
    "│\n",
    "└─valid\n",
    "        val.de\n",
    "        val.en\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Replace is False and data exists, so doing nothing. Use replace=True to re-download the data.\n",
      "Replace is False and data exists, so doing nothing. Use replace=True to re-download the data.\n",
      "Replace is False and data exists, so doing nothing. Use replace=True to re-download the data.\n"
     ]
    }
   ],
   "source": [
    "from download import download\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "# 训练、验证、测试数据集下载地址\n",
    "urls = {\n",
    "    'train': 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',\n",
    "    'valid': 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',\n",
    "    'test': 'http://www.quest.dcs.shef.ac.uk/wmt17_files_mmt/mmt_task1_test2016.tar.gz'\n",
    "}\n",
    "\n",
    "# 指定保存路径为 `home_path/.mindspore_examples`\n",
    "cache_dir = Path.home() / '.mindspore_examples'\n",
    "\n",
    "train_path = download(urls['train'], os.path.join(cache_dir, 'train'), kind='tar.gz')\n",
    "valid_path = download(urls['valid'], os.path.join(cache_dir, 'valid'), kind='tar.gz')\n",
    "test_path = download(urls['test'], os.path.join(cache_dir, 'test'), kind='tar.gz')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据预处理\n",
    "\n",
    "在使用数据进行模型训练等操作时，我们需要对数据进行预处理，流程如下：\n",
    "\n",
    "1. 加载数据集，目前数据为句子形式的文本，需要进行分词，即将句子拆解为单独的词元（token，可以为字符或者单词）；\n",
    "    - 分词可以使用`spaCy`创建分词器（tokenizer）：`de_core_news_sm`，`en_core_web_sm`，需要手动下载；\n",
    "    - 分词后，去除多余的空格，统一大小写等；\n",
    "2. 将每个词元映射到从0开始的数字索引中（为节约存储空间，可过滤掉词频低的词元），词元和数字索引所构成的集合叫做词典（vocabulary）；\n",
    "3. 添加特殊占位符，标明序列的起始与结束，统一序列长度，并创建数据迭代器；\n",
    "\n",
    "#### 数据加载器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/miniconda3/envs/mindspore/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "[ERROR] ME(39942:139912596019008,MainProcess):2023-03-14-04:12:41.490.398 [mindspore/run_check/_check_version.py:226] Cuda ['10.1', '11.1', '11.6'] version(libcu*.so need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches, please refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[ERROR] ME(39942:139912596019008,MainProcess):2023-03-14-04:12:41.541.563 [mindspore/run_check/_check_version.py:226] Cuda ['10.1', '11.1', '11.6'] version(libcudnn*.so need by mindspore-gpu) is not found, please confirm that the path of cuda is set to the env LD_LIBRARY_PATH, or check whether the CUDA version in wheel package and the CUDA runtime in current device matches, please refer to the installation guidelines: https://www.mindspore.cn/install\n"
     ]
    }
   ],
   "source": [
    "import spacy\n",
    "from functools import partial\n",
    "from mindnlp.transforms import BasicTokenizer\n",
    "\n",
    "class Multi30K():\n",
    "    \"\"\"Multi30K数据集加载器\n",
    "\n",
    "    加载Multi30K数据集并处理为一个Python迭代对象。\n",
    "\n",
    "    \"\"\"\n",
    "    def __init__(self, path):\n",
    "        self.data = self._load(path)\n",
    "\n",
    "    def _load(self, path):\n",
    "        tokenizer = BasicTokenizer()\n",
    "        def tokenize(text, spacy_lang):\n",
    "            # 去除多余空格，统一大小写\n",
    "            text = text.rstrip()\n",
    "            # return [tok.text.lower() for tok in spacy_lang.tokenizer(text)]\n",
    "            return tokenizer(text).tolist()\n",
    "\n",
    "        # 加载英、德语分词器\n",
    "        tokenize_de = partial(tokenize, spacy_lang=spacy.load('de_core_news_sm'))\n",
    "        tokenize_en = partial(tokenize, spacy_lang=spacy.load('en_core_web_sm'))\n",
    "\n",
    "        # 读取Multi30K数据，并进行分词\n",
    "        members = {i.split('.')[-1]: i for i in os.listdir(path)}\n",
    "        de_path = os.path.join(path, members['de'])\n",
    "        en_path = os.path.join(path, members['en'])\n",
    "        with open(de_path, 'r') as de_file:\n",
    "            de = de_file.readlines()[:-1]\n",
    "            de = [tokenize_de(i) for i in de]\n",
    "        with open(en_path, 'r') as en_file:\n",
    "            en = en_file.readlines()[:-1]\n",
    "            en = [tokenize_en(i) for i in en]\n",
    "\n",
    "        return list(zip(de, en))\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset, valid_dataset, test_dataset = Multi30K(train_path), Multi30K(valid_path), Multi30K(test_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对解压和分词结果进行测试，打印测试数据集第一组英德语文本，可以看到每一个单词和标点符号已经被单独分离出来。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "de = ['Ein', 'Mann', 'mit', 'einem', 'orangefarbenen', 'Hut', ',', 'der', 'etwas', 'anstarrt', '.']\n",
      "en = ['A', 'man', 'in', 'an', 'orange', 'hat', 'starring', 'at', 'something', '.']\n"
     ]
    }
   ],
   "source": [
    "for de, en in test_dataset:\n",
    "    print(f'de = {de}')\n",
    "    print(f'en = {en}')\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 词典"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Vocab:\n",
    "    \"\"\"通过词频字典，构建词典\"\"\"\n",
    "\n",
    "    special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']\n",
    "\n",
    "    def __init__(self, word_count_dict, min_freq=1):\n",
    "        self.word2idx = {}\n",
    "        for idx, tok in enumerate(self.special_tokens):\n",
    "            self.word2idx[tok] = idx\n",
    "\n",
    "        # 过滤低词频的词元\n",
    "        filted_dict = {\n",
    "            w: c\n",
    "            for w, c in word_count_dict.items() if c >= min_freq\n",
    "        }\n",
    "        for w, _ in filted_dict.items():\n",
    "            self.word2idx[w] = len(self.word2idx)\n",
    "\n",
    "        self.idx2word = {idx: word for word, idx in self.word2idx.items()}\n",
    "\n",
    "        self.bos_idx = self.word2idx['<bos>']  # 特殊占位符：序列开始\n",
    "        self.eos_idx = self.word2idx['<eos>']  # 特殊占位符：序列结束\n",
    "        self.pad_idx = self.word2idx['<pad>']  # 特殊占位符：补充字符\n",
    "        self.unk_idx = self.word2idx['<unk>']  # 特殊占位符：低词频词元或未曾出现的词元\n",
    "\n",
    "    def _word2idx(self, word):\n",
    "        \"\"\"单词映射至数字索引\"\"\"\n",
    "        if word not in self.word2idx:\n",
    "            return self.unk_idx\n",
    "        return self.word2idx[word]\n",
    "\n",
    "    def _idx2word(self, idx):\n",
    "        \"\"\"数字索引映射至单词\"\"\"\n",
    "        if idx not in self.idx2word:\n",
    "            raise ValueError('input index is not in vocabulary.')\n",
    "        return self.idx2word[idx]\n",
    "\n",
    "    def encode(self, word_or_list):\n",
    "        \"\"\"将单个单词或单词数组映射至单个数字索引或数字索引数组\"\"\"\n",
    "        if isinstance(word_or_list, list):\n",
    "            return [self._word2idx(i) for i in word_or_list]\n",
    "        return self._word2idx(word_or_list)\n",
    "\n",
    "    def decode(self, idx_or_list):\n",
    "        \"\"\"将单个数字索引或数字索引数组映射至单个单词或单词数组\"\"\"\n",
    "        if isinstance(idx_or_list, list):\n",
    "            return [self._idx2word(i) for i in idx_or_list]\n",
    "        return self._idx2word(idx_or_list)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.word2idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过自定义词频字典进行测试，我们可以看到词典已去除词频少于2的词元c，并加入了默认的四个特殊占位符，故词典整体长度为：4 - 1 + 4 = 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_count = {'a':20, 'b':10, 'c':1, 'd':2}\n",
    "\n",
    "vocab = Vocab(word_count, min_freq=2)\n",
    "len(vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用`collections`中的`Counter`和`OrderedDict`统计英/德语每个单词在整体文本中出现的频率。构建词频字典，然后再将词频字典转为词典。\n",
    "\n",
    "在分配数字索引时有一个小技巧：常用的词元对应数值较小的索引，这样可以节约空间。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter, OrderedDict\n",
    "\n",
    "def build_vocab(dataset):\n",
    "    de_words, en_words = [], []\n",
    "    for de, en in dataset:\n",
    "        de_words.extend(de)\n",
    "        en_words.extend(en)\n",
    "\n",
    "    de_count_dict = OrderedDict(sorted(Counter(de_words).items(), key=lambda t: t[1], reverse=True))\n",
    "    en_count_dict = OrderedDict(sorted(Counter(en_words).items(), key=lambda t: t[1], reverse=True))\n",
    "\n",
    "    return Vocab(de_count_dict, min_freq=2), Vocab(en_count_dict, min_freq=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unique tokens in de vocabulary: 8050\n"
     ]
    }
   ],
   "source": [
    "de_vocab, en_vocab = build_vocab(train_dataset)\n",
    "print('Unique tokens in de vocabulary:', len(de_vocab))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 数据迭代器\n",
    "\n",
    "数据预处理的最后一步是创建数据迭代器，我们在进一步处理数据（包括批处理，添加起始和终止符号，统一序列长度）后，将数据以张量的形式返回。\n",
    "\n",
    "创建数据迭代器需要如下参数：\n",
    "\n",
    "- `dataset`：分词后的数据集\n",
    "- `de_vocab`：德语词典\n",
    "- `en_vocab`：英语词典\n",
    "- `batch_size`：批量大小，即一个batch中包含多少个序列\n",
    "- `max_len`：序列最大长度，为最长有效文本长度 + 2（序列开始、序列结束占位符），如不满则补齐，如超过则丢弃\n",
    "- `drop_remainder`：是否在最后一个batch未满时，丢弃该batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "\n",
    "class Iterator():\n",
    "    \"\"\"创建数据迭代器\"\"\"\n",
    "    def __init__(self, dataset, de_vocab, en_vocab, batch_size, max_len=32, drop_reminder=False):\n",
    "        self.dataset = dataset\n",
    "        self.de_vocab = de_vocab\n",
    "        self.en_vocab = en_vocab\n",
    "\n",
    "        self.batch_size = batch_size\n",
    "        self.max_len = max_len\n",
    "        self.drop_reminder = drop_reminder\n",
    "\n",
    "        length = len(self.dataset) // batch_size\n",
    "        self.len = length if drop_reminder else length + 1  # 批量数量\n",
    "\n",
    "    def __call__(self):\n",
    "        def pad(idx_list, vocab, max_len):\n",
    "            \"\"\"统一序列长度，并记录有效长度\"\"\"\n",
    "            idx_pad_list, idx_len = [], []\n",
    "            # 当前序列度超过最大长度时，将超出的部分丢弃；当前序列长度小于最大长度时，用占位符补齐\n",
    "            for i in idx_list:\n",
    "                if len(i) > max_len - 2:\n",
    "                    idx_pad_list.append(\n",
    "                        [vocab.bos_idx] + i[:max_len-2] + [vocab.eos_idx]\n",
    "                    )\n",
    "                    idx_len.append(max_len)\n",
    "                else:\n",
    "                    idx_pad_list.append(\n",
    "                        [vocab.bos_idx] + i + [vocab.eos_idx] + [vocab.pad_idx] * (max_len - len(i) - 2)\n",
    "                    )\n",
    "                    idx_len.append(len(i) + 2)\n",
    "            return idx_pad_list, idx_len\n",
    "\n",
    "        def sort_by_length(src, trg):\n",
    "            \"\"\"对德/英语的字段长度进行排序\"\"\"\n",
    "            data = zip(src, trg)\n",
    "            data = sorted(data, key=lambda t: len(t[0]), reverse=True)\n",
    "            return zip(*list(data))\n",
    "\n",
    "        def encode_and_pad(batch_data, max_len):\n",
    "            \"\"\"将批量中的文本数据转换为数字索引，并统一每个序列的长度\"\"\"\n",
    "            # 将当前批量数据中的词元转化为索引\n",
    "            src_data, trg_data = zip(*batch_data)\n",
    "            src_idx = [self.de_vocab.encode(i) for i in src_data]\n",
    "            trg_idx = [self.en_vocab.encode(i) for i in trg_data]\n",
    "\n",
    "            # 统一序列长度\n",
    "            src_idx, trg_idx = sort_by_length(src_idx, trg_idx)\n",
    "            src_idx_pad, src_len = pad(src_idx, de_vocab, max_len)\n",
    "            trg_idx_pad, _ = pad(trg_idx, en_vocab, max_len)\n",
    "\n",
    "            return src_idx_pad, src_len, trg_idx_pad\n",
    "\n",
    "        for i in range(self.len):\n",
    "            # 获取当前批量的数据\n",
    "            if i == self.len - 1 and not self.drop_reminder:\n",
    "                batch_data = self.dataset[i * self.batch_size:]\n",
    "            else:\n",
    "                batch_data = self.dataset[i * self.batch_size: (i+1) * self.batch_size]\n",
    "\n",
    "            src_idx, src_len, trg_idx = encode_and_pad(batch_data, self.max_len)\n",
    "            # 将序列数据转换为tensor\n",
    "            yield mindspore.Tensor(src_idx, mindspore.int32), \\\n",
    "                mindspore.Tensor(src_len, mindspore.int32), \\\n",
    "                mindspore.Tensor(trg_idx, mindspore.int32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_iterator = Iterator(train_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=True)\n",
    "valid_iterator = Iterator(valid_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=False)\n",
    "test_iterator = Iterator(test_dataset, de_vocab, en_vocab, batch_size=128, max_len=32, drop_reminder=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型构建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 编码器（Encoder）\n",
    "\n",
    "在编码器中，我们输入一个序列$X=\\{x_1, x_2, ..., x_T\\}$，在embedding层将其转化为向量，循环计算隐藏状态$H=\\{h_1, h_2, ..., h_T\\}$，并在最后的隐藏状态中返回上下文向量$z=h_T$。\n",
    "\n",
    "实现编码器的方式有很多种，在这里我们使用的是门控循环单元模型（Gated Rrecurrent Units, GRU）。它在原始RNN的基础上引入了门机制（gate mechanism），用以控制输入隐藏状态和从隐藏状态输出的信息。其中，更新门（update gate， 又称记忆门，一般用$z_t$表示）用于控制前一时刻的状态信息$h_{t-1}$被带入到当前状态$h_t$中的程度。重置门（reset gate，一般用$r_t$表示）控制前一状态$h_t$有多少信息被写入到当前候选集$n_t$上。\n",
    "\n",
    "$$h_t = \\text{RNN}(e(x_t), h_{t-1})$$\n",
    "\n",
    "在进行文本翻译类任务时，我们一般使用双向GRU，即在训练中同时考虑当前词语之前及之后的文本内容。双向GRU的每层由两个RNN构成，前向RNN由左至右循环计算隐藏状态，反向RNN从右至左计算隐藏状态，公式表达如下：\n",
    "\n",
    "$$\\begin{align*}\n",
    "h_t^\\rightarrow &= \\text{EncoderGRU}^\\rightarrow(e(x_t^\\rightarrow),h_{t-1}^\\rightarrow)\\\\\n",
    "h_t^\\leftarrow &= \\text{EncoderGRU}^\\leftarrow(e(x_t^\\leftarrow),h_{t-1}^\\leftarrow)\n",
    "\\end{align*}$$\n",
    "\n",
    "每个RNN网络在观察到句子中的最后一个词后，输出一个上下文向量，前向RNN的输出为$z^\\rightarrow=h_T^\\rightarrow$，反向RNN的输出为$z^\\leftarrow=h_T^\\leftarrow$。\n",
    "\n",
    "![avatar3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/nlp/images/seq2seq_3.png)\n",
    "\n",
    "> 图片来源：\n",
    ">\n",
    "> https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb\n",
    "\n",
    "编码器最终会返回两项：`outputs`和`hidden`。\n",
    "\n",
    "- `outputs`为双向GRU最上层隐藏状态，形状为\\[max_len, batch_size, hid_dim * num_directions\\]。以$t=1$时刻为例，其对应的output为前向RNN中$t=1$时刻最上层隐藏状态和反向RNN中$t=T$时刻的结合，即$h_1 = [h_1^\\rightarrow; h_{T}^\\leftarrow]$；\n",
    "\n",
    "- `hidden`表示每层的最终隐藏状态，即上文提到的上下文向量。后续将作为编码器初始时刻的隐藏状态$s_0$，但由于编码器（decoder）的结构并不是双向的，仅仅需要一个上下文向量$z$，为了与之对应，我们将编码器中的两个向量组合起来，放入全连接层$g$中，并最后使用激活函数$tanh$；\n",
    "\n",
    "$$z=\\tanh(g(h_T^\\rightarrow, h_T^\\leftarrow)) = \\tanh(g(z^\\rightarrow, z^\\leftarrow)) = s_0$$\n",
    "\n",
    "MindSpore为大家提供了GRU的接口，可以在编码器搭建中直接调用，通过设置参数`bidirectional=True`使用双向GRU。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "\n",
    "class Encoder(nn.Cell):\n",
    "    def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(input_dim, emb_dim)  # Embedding层\n",
    "        self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)  # 双向GRU层\n",
    "        self.fc = nn.Dense(enc_hid_dim * 2, dec_hid_dim)  # 全连接层\n",
    "\n",
    "        self.dropout = nn.Dropout(p=dropout)  # dropout，防止过拟合\n",
    "\n",
    "    def construct(self, src, src_len):\n",
    "        \"\"\"构建编码器\n",
    "\n",
    "        Args:\n",
    "            src: 源序列，为已经转换为数字索引并统一长度的序列；shape = [src len, batch_size]\n",
    "            src_len: 有效长度；shape = [batch_size, ]\n",
    "        \"\"\"\n",
    "\n",
    "        # 将输入源序列转化为向量，并进行暂退（dropout）\n",
    "        # shape = [src len, batch size, emb dim]\n",
    "        embedded = self.dropout(self.embedding(src))\n",
    "        # 计算输出\n",
    "        # shape = [src len, batch size, enc hid dim*2]\n",
    "        outputs, hidden = self.rnn(embedded, seq_length=src_len)\n",
    "        # 为适配解码器，合并两个上下文函数\n",
    "        # shape = [batch size, dec hid dim]\n",
    "        hidden = ops.tanh(self.fc(ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)))\n",
    "\n",
    "        return outputs, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 注意力层（Attention）\n",
    "\n",
    "在机器翻译中，每个生成的词可能对应源句子中不同的词，而传统的无注意力机制的Seq2Seq模型更偏向于关注句子中的最后一个词。为了进一步优化模型，我们引入了注意力机制。\n",
    "\n",
    "注意力机制便是赋予源句子和目标句子中对应的词以更高的权重，它整合了我们目前为止编码与解码的所有信息，并输出一个表示注意力权重的向量$a_t$，用来决定在下一步的预测$\\hat{y}_{t+}$中应该给予哪些词更高的关注度。\n",
    "\n",
    "![avatar4](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/nlp/images/seq2seq_4.png)\n",
    "\n",
    "> 图片来源：\n",
    ">\n",
    "> https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb\n",
    "\n",
    "首先，我们需要明确编码器中的每一个隐藏状态和解码器中上一个时刻隐藏状态之间的匹配程度$E_t$。\n",
    "\n",
    "截止到当前的时刻$t$，编码器（encoder）中的所有信息为全部前向和后向RNN的隐藏状态的组合$H$，是一个有$T$个张量的序列；解码器（decoder）中的所有信息为上一时刻的隐藏状态$s_{t-1}$，是一个单独的张量。为了统一二者的维度，我们需要将解码器中上一时刻的隐藏状态$s_{t-1}$重复$T$次，接着把处理好的解码器信息与编码器信息堆叠起来，并输入到线性层`att`和激活函数$\\text{tanh}$中，计算编码器与解码器隐藏状态之间的能量$E_t$。\n",
    "\n",
    "$$E_t = \\tanh(\\text{attn}(s_{t-1}, H))$$\n",
    "\n",
    "当前$E_t$的每个batch中tensor的形状为\\[dec hid dim, src len\\]，但是注意最终的注意力权重是需要作用在源序列之上的，所以注意力权重的维度也应该与源句子的维度\\[src len\\]相对应。为此，我们引入了一个可学习的张量$v$。\n",
    "\n",
    "$$\\hat{a}_t = v E_t$$\n",
    "\n",
    "我们可以将$v$看作是所有编码器隐藏状态的加权和的权重，简单来说便是对源序列中的每个词的关注程度。$v$的参数是随机初始化的，它会在反向传播中与模型的其余部分一起学习。此外，$v$并不依赖于时间，所以在解码中每个时间步长使用的$v$是一致的。\n",
    "\n",
    "最终，我们使用$\\text{softmax}$函数，来保证注意力向量$a_t$中每一个元素的大小都在0-1之间，并且所有元素加和为1。\n",
    "\n",
    "$$a_t = \\text{softmax}(\\hat{a_t})$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Cell):\n",
    "    def __init__(self, enc_hid_dim, dec_hid_dim):\n",
    "        super().__init__()\n",
    "\n",
    "        # attention线性层\n",
    "        self.attn = nn.Dense((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)\n",
    "        # v， 用不带有bias的线性层表示\n",
    "        # shape = [1, dec hid dim]\n",
    "        self.v = nn.Dense(dec_hid_dim, 1, has_bias=False)\n",
    "\n",
    "    def construct(self, hidden, encoder_outputs, mask):\n",
    "        \"\"\"Attention层\n",
    "\n",
    "        Args:\n",
    "            hidden: 解码器上一个时刻的隐藏状态；shape = [batch size, dec hid dim]\n",
    "            encoder_outputs: 编码器的输出，前向与反向RNN的隐藏状态；shape = [src len, batch size, enc hid dim * 2]\n",
    "            mask: 将<pad>占位符的注意力权重替换为0或者很小的数值；shape = [batch size, src len]\n",
    "        \"\"\"\n",
    "\n",
    "        src_len = encoder_outputs.shape[0]\n",
    "\n",
    "        # 重复解码器隐藏状态src len次，对齐维度\n",
    "        # shape = [batch size, src len, dec hid dim]\n",
    "        hidden = ops.tile(hidden.expand_dims(1), (1, src_len, 1))\n",
    "\n",
    "        # 将编码器输出中的第1、2维度进行交换，对齐维度\n",
    "        # shape = [batch size, src len, enc hid dim*2]\n",
    "        encoder_outputs = encoder_outputs.transpose(1, 0, 2)\n",
    "\n",
    "        # 计算E_t\n",
    "        # shape = [batch size, src len, dec hid dim]\n",
    "        energy = ops.tanh(self.attn(ops.concat((hidden, encoder_outputs), axis=2)))\n",
    "\n",
    "        # 计算v * E_t\n",
    "        # shape = [batch size, src len]\n",
    "        attention = self.v(energy).squeeze(2)\n",
    "\n",
    "        # 不需要考虑序列中<pad>占位符的注意力权重\n",
    "        attention = attention.masked_fill(mask == 0, -1e10)\n",
    "\n",
    "        return ops.softmax(attention, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 解码器（Decoder）\n",
    "\n",
    "解码器中包含了上述的注意力层，在获得注意力权重向量$a_t$后，我们将其应用在编码器的隐藏状态$H$上，得到一个表示编码器隐藏状态加权和的向量$w_t$。\n",
    "\n",
    "$$w_t = a_t H$$\n",
    "\n",
    "我们将该向量$w_t$，连同embedding后的输入$d(y_t)$，上一时刻的隐藏状态$s_{t-1}$，一起放入编码器的RNN网络中，并将输出送入线性层$f$，得到关于目标句子中下一时刻出现的单词的预测。\n",
    "\n",
    "$$s_t = \\text{DecoderGRU}(d(y_t), w_t, s_{t-1})$$\n",
    "\n",
    "$$\\hat{y}_{t+1} = f(d(y_t), w_t, s_t)$$\n",
    "\n",
    "![avatar5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/application/source_zh_cn/nlp/images/seq2seq_5.png)\n",
    "\n",
    "> 图片来源：\n",
    ">\n",
    "> https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Cell):\n",
    "    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):\n",
    "        super().__init__()\n",
    "        self.output_dim = output_dim\n",
    "        self.attention = attention\n",
    "\n",
    "        self.embedding = nn.Embedding(output_dim, emb_dim)\n",
    "        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)\n",
    "        self.fc_out = nn.Dense((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "    def construct(self, inputs, hidden, encoder_outputs, mask):\n",
    "        \"\"\"构建解码器\n",
    "\n",
    "        Args:\n",
    "            input: 输入的单词；shape = [batch size]\n",
    "            hidden: 解码器上一时刻的隐藏状态；shape = [batch size, dec hid dim]\n",
    "            encoder_outputs: 编码器的输出，前向与反向RNN的隐藏状态；shape = [src len, batch size, enc hid dim * 2]\n",
    "            mask: 将<pad>占位符的注意力权重替换为0或者很小的数值；shape = [batch size, src len]\n",
    "        \"\"\"\n",
    "\n",
    "        # 为输入增加额外维度\n",
    "        # shape = [1, batch size]\n",
    "        inputs = inputs.expand_dims(0)\n",
    "\n",
    "        # 输入词的embedding输出， d(y_t)\n",
    "        # shape = [1, batch size, emb dim]\n",
    "        embedded = self.dropout(self.embedding(inputs))\n",
    "\n",
    "        # 注意力权重向量, a_t\n",
    "        # shape = [batch size, src len]\n",
    "        a = self.attention(hidden, encoder_outputs, mask)\n",
    "\n",
    "        # 为注意力权重增加额外维度\n",
    "        # shape = [batch size, 1, src len]\n",
    "        a = a.expand_dims(1)\n",
    "\n",
    "        # 将编码器隐藏状态中的第1、2维度进行交换\n",
    "        # shape = [batch size, src len, enc hid dim * 2]\n",
    "        encoder_outputs = encoder_outputs.transpose(1, 0, 2)\n",
    "\n",
    "        # 计算w_t\n",
    "        # shape = [batch size, 1, enc hid dim * 2]\n",
    "        weighted = ops.bmm(a, encoder_outputs)\n",
    "\n",
    "        # 将w_t的第1、2维度进行交换\n",
    "        # shape = [1, batch size, enc hid dim * 2]\n",
    "        weighted = weighted.transpose(1, 0, 2)\n",
    "\n",
    "        # 将emdedded与weighted堆叠在一起，后输入进RNN层\n",
    "        # rnn_input shape = [1, batch size, (enc hid dim * 2) + emb dim]\n",
    "        # output shape = [seq len = 1, batch size, dec hid dim * n directions]\n",
    "        # hidden shape = [n layers (1) * n directions (1) = 1, batch size, dec hid dim]\n",
    "        rnn_input = ops.concat((embedded, weighted), axis=2)\n",
    "        output, hidden = self.rnn(rnn_input, hidden.expand_dims(0))\n",
    "\n",
    "        # 去除多余的第1维度\n",
    "        embedded = embedded.squeeze(0)\n",
    "        output = output.squeeze(0)\n",
    "        weighted = weighted.squeeze(0)\n",
    "\n",
    "        # 将embedded，weighted和hidden堆叠起来，并输入线性层，预测下一个词\n",
    "        # shape = [batch size, output dim]\n",
    "        prediction = self.fc_out(ops.concat((output, weighted, embedded), axis=1))\n",
    "\n",
    "        return prediction, hidden.squeeze(0), a.squeeze(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Seq2Seq\n",
    "\n",
    "Seq2Seq封装器将我们之前创建的编码器与解码器合并起来。\n",
    "\n",
    "简单梳理一下整体过程：\n",
    "\n",
    "1. 初始化空数列`outputs`，用于储存每次的预测结果；\n",
    "2. 源序列$X$作为编码器的输入，输出$z$和$H$；\n",
    "3. 解码器初始时刻的隐藏状态为编码器中输出的上下文向量，即编码器最后时刻的隐藏状态，$s_0 = z = h_T$；\n",
    "4. 解码器最开始的输入$y_1$为表示序列开始的占位符\\<bos\\>;\n",
    "5. 重复以下步骤：\n",
    "    - 将此时刻$t$的输入$y_t$，上一时刻的隐藏状态$s_{t-1}$，编码器中的所有隐藏状态$H$作为输入；\n",
    "    - 输出对下一时刻的预测$\\hat{y}_{t+1}$，以及新的隐藏状态$s_t$；\n",
    "    - 将预测结果存入`outputs`中\n",
    "    - 确定是否使用teacher forcing，如是，$y_{t+1} = \\hat{y}_{t+1}$，如否，下一时刻的输入为目标序列中的词；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import Tensor\n",
    "\n",
    "class Seq2Seq(nn.Cell):\n",
    "    def __init__(self, encoder, decoder, src_pad_idx, teacher_forcing_ratio):\n",
    "        super().__init__()\n",
    "\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.src_pad_idx = src_pad_idx\n",
    "        self.teacher_forcing_ratio = teacher_forcing_ratio  # 使用teacher forcing的可能性\n",
    "\n",
    "    def create_mask(self, src):\n",
    "        \"\"\"标记出每个序列中<pad>占位符的位置\"\"\"\n",
    "        mask = (src != self.src_pad_idx).astype(mindspore.int32).swapaxes(1, 0)\n",
    "        return mask\n",
    "\n",
    "    def construct(self, src, src_len, trg, trg_len=None):\n",
    "        \"\"\"构建seq2seq模型\n",
    "\n",
    "        Args:\n",
    "            src: 源序列；shape = [src len, batch size]\n",
    "            src_len: 源序列长度；shape = [batch size]\n",
    "            trg: 目标序列；shape = [trg len, batch size]\n",
    "            trg_len: 目标序列长度；shape = [trg len, batch size]\n",
    "        \"\"\"\n",
    "        if trg_len is None:\n",
    "            trg_len = trg.shape[0]\n",
    "\n",
    "        #存储解码器输出\n",
    "        outputs = []\n",
    "\n",
    "        # 编码器（encoder）：\n",
    "        # 输入：源序列、源序列长度\n",
    "        # 输出1：编码器中所有前向与反向RNN 的隐藏状态 encoder_outputs\n",
    "        # 输出2：编码器中前向与反向RNN中最后时刻的隐藏状态放入线性层后的输出 hidden\n",
    "        encoder_outputs, hidden = self.encoder(src, src_len)\n",
    "\n",
    "        #解码器的第一个输入是表示序列开始的占位符<bos>\n",
    "        inputs = trg[0]\n",
    "\n",
    "        # 标记源序列中<pad>占位符的位置\n",
    "        # shape = [batch size, src len]\n",
    "        mask = self.create_mask(src)\n",
    "\n",
    "        for t in range(1, trg_len):\n",
    "\n",
    "            # 解码器（decoder）：\n",
    "            # 输入：源句子序列 inputs、前一时刻的隐藏状态 hidden、编码器所有前向与反向RNN的隐藏状态\n",
    "            # 标明每个句子中的<pad>，方便计算注意力权重时忽略该部分\n",
    "            # 输出：预测结果 output、新的隐藏状态 hidden、注意力权重（忽略）\n",
    "            output, hidden, _ = self.decoder(inputs, hidden, encoder_outputs, mask)\n",
    "\n",
    "            # 将预测结果放入之前的存储中\n",
    "            outputs.append(output)\n",
    "\n",
    "            #找出对应预测概率最大的词元\n",
    "            top1 = output.argmax(1).astype(mindspore.int32)\n",
    "\n",
    "            if self.training:\n",
    "                #如果目前为模型训练状态，则按照之前设定的概率使用teacher forcing\n",
    "                minval = Tensor(0, mindspore.float32)\n",
    "                maxval = Tensor(1, mindspore.float32)\n",
    "                teacher_force = ops.uniform((1,), minval, maxval) < self.teacher_forcing_ratio\n",
    "                # 如使用teacher forcing，则将目标序列中对应的词元作为下一个输入\n",
    "                # 如不使用teacher forcing，则将预测结果作为下一个输入\n",
    "                inputs = trg[t] if teacher_force else top1\n",
    "            else:\n",
    "                inputs = top1\n",
    "\n",
    "        # 将所有输出整合为tensor\n",
    "        outputs = ops.stack(outputs, axis=0)\n",
    "\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练\n",
    "\n",
    "模型参数，编码器，注意力层，解码器以及Seq2Seq网络初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = len(de_vocab)  # 输入维度\n",
    "output_dim = len(en_vocab)  # 输出维度\n",
    "enc_emb_dim = 256  # Encoder Embedding层维度\n",
    "dec_emb_dim = 256  # Decoder Embedding层维度\n",
    "enc_hid_dim = 512  # Encoder 隐藏层维度\n",
    "dec_hid_dim = 512  # Decoder 隐藏层维度\n",
    "enc_dropout = 0.5  # Encoder Dropout\n",
    "dec_dropout = 0.5  # Decoder Dropout\n",
    "src_pad_idx = de_vocab.pad_idx  # 德语词典中pad占位符的数字索引\n",
    "trg_pad_idx = en_vocab.pad_idx  # 英语词典中pad占位符的数字索引\n",
    "\n",
    "attn = Attention(enc_hid_dim, dec_hid_dim)\n",
    "encoder = Encoder(input_dim, enc_emb_dim, enc_hid_dim, dec_hid_dim, enc_dropout)\n",
    "decoder = Decoder(output_dim, dec_emb_dim, enc_hid_dim, dec_hid_dim, dec_dropout, attn)\n",
    "\n",
    "model = Seq2Seq(encoder, decoder, src_pad_idx, 0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8050, 6198)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_dim, output_dim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "损失函数、优化器初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = nn.Adam(model.trainable_params(), learning_rate=0.001)  # 损失函数\n",
    "loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)  # 优化器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意在模型训练中，可能会出现权重更新过大的情况。这会导致数值上溢或者下溢，最终造成梯度爆炸（gradient explosion）。为解决这个问题，我们需要在反向传播计算梯度之后，使用梯度裁剪(gradient clipping)，再将裁剪后的梯度传入优化器进行网络更新。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.ops as ops\n",
    "\n",
    "def clip_by_norm(clip_norm, t, axis=None):\n",
    "    \"\"\"给定张量t和裁剪参数clip_norm，对t进行正则化\n",
    "\n",
    "    使得t在axes维度上的L2-norm小于等于clip_norm。\n",
    "\n",
    "    Args:\n",
    "        t: tensor，数据类型为float\n",
    "        clip_norm: scalar，数值需大于0；梯度裁剪阈值，数据类型为float\n",
    "        axis: Union[None, int, tuple(int)]，数据类型为int32；计算L2-norm参考的维度，如为Norm，则参考所有维度\n",
    "    \"\"\"\n",
    "\n",
    "    # 计算L2-norm\n",
    "    t2 = t * t\n",
    "    l2sum = t2.sum(axis=axis, keepdims=True)\n",
    "    pred = l2sum > 0\n",
    "    # 将加和中等于0的元素替换为1，避免后续出现NaN\n",
    "    l2sum_safe = ops.select(pred, l2sum, ops.ones_like(l2sum))\n",
    "    l2norm = ops.select(pred, ops.sqrt(l2sum_safe), l2sum)\n",
    "    # 比较L2-norm和clip_norm，如L2-norm超过阈值，进行裁剪\n",
    "    # 剪裁方法：output(x) = (x * clip_norm)/max(|x|, clip_norm)\n",
    "    intermediate = t * clip_norm\n",
    "    cond = l2norm > clip_norm\n",
    "    t_clip = ops.identity(intermediate / ops.select(cond, l2norm, clip_norm))\n",
    "\n",
    "    return t_clip\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "模型训练，训练途中使用验证数据集进行验证评估，并保存效果最好的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_fn(src, src_len, trg):\n",
    "    \"\"\"前向网络\"\"\"\n",
    "    src = src.swapaxes(0, 1)\n",
    "    trg = trg.swapaxes(0, 1)\n",
    "\n",
    "    output = model(src, src_len, trg)\n",
    "    output_dim = output.shape[-1]\n",
    "    output = output.view(-1, output_dim)\n",
    "    trg = trg[1:].view(-1)\n",
    "    loss = loss_fn(output, trg)\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "# 反向传播计算梯度\n",
    "grad_fn = mindspore.value_and_grad(forward_fn, None, opt.parameters)\n",
    "\n",
    "def train_step(src, src_len, trg, clip):\n",
    "    \"\"\"单步训练\"\"\"\n",
    "    loss, grads = grad_fn(src, src_len, trg)\n",
    "    grads = ops.HyperMap()(ops.partial(clip_by_norm, clip), grads)  # 梯度裁剪\n",
    "    opt(grads)  # 更新网络参数\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def train(iterator, clip, epoch=0):\n",
    "    \"\"\"模型训练\"\"\"\n",
    "    model.set_train(True)\n",
    "    num_batches = len(iterator)\n",
    "    total_loss = 0  # 所有batch训练loss的累加\n",
    "    total_steps = 0  # 训练步数\n",
    "\n",
    "    with tqdm(total=num_batches) as t:\n",
    "        t.set_description(f'Epoch: {epoch}')\n",
    "        for src, src_len, trg in iterator():\n",
    "            loss = train_step(src, src_len, trg, clip)  # 当前batch的loss\n",
    "            total_loss += loss.asnumpy()\n",
    "            total_steps += 1\n",
    "            curr_loss = total_loss / total_steps  # 当前的平均loss\n",
    "            t.set_postfix({'loss': f'{curr_loss:.2f}'})\n",
    "            t.update(1)\n",
    "\n",
    "    return total_loss / total_steps\n",
    "\n",
    "\n",
    "def evaluate(iterator):\n",
    "    \"\"\"模型验证\"\"\"\n",
    "    model.set_train(False)\n",
    "    num_batches = len(iterator)\n",
    "    total_loss = 0  # 所有batch训练loss的累加\n",
    "    total_steps = 0  # 训练步数\n",
    "\n",
    "    with tqdm(total=num_batches) as t:\n",
    "        for src, src_len, trg in iterator():\n",
    "            loss = forward_fn(src, src_len, trg)  # 当前batch的loss\n",
    "            total_loss += loss.asnumpy()\n",
    "            total_steps += 1\n",
    "            curr_loss = total_loss / total_steps  # 当前的平均loss\n",
    "            t.set_postfix({'loss': f'{curr_loss:.2f}'})\n",
    "            t.update(1)\n",
    "\n",
    "    return total_loss / total_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 0: 100%|███████████████████████████████████████████████████████████████████████▋| 225/226 [11:20<00:33, 33.07s/it, loss=4.87]"
     ]
    }
   ],
   "source": [
    "from mindspore import save_checkpoint\n",
    "\n",
    "num_epochs = 10  # 训练迭代数\n",
    "clip = 1.0  # 梯度裁剪阈值\n",
    "best_valid_loss = float('inf')  # 当前最佳验证损失\n",
    "ckpt_file_name = os.path.join(cache_dir, 'seq2seq.ckpt')  # 模型保存路径\n",
    "\n",
    "for i in range(num_epochs):\n",
    "    # 模型训练，网络权重更新\n",
    "    train_loss = train(train_iterator, clip, i)\n",
    "    # 网络权重更新后对模型进行验证\n",
    "    valid_loss = evaluate(valid_iterator)\n",
    "\n",
    "    # 保存当前效果最好的模型\n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        save_checkpoint(model, ckpt_file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate_sentence(sentence, de_vocab, en_vocab, model, max_len=32):\n",
    "    \"\"\"给定德语句子，返回英文翻译\"\"\"\n",
    "    model.set_train(False)\n",
    "    # 对输入句子进行分词\n",
    "    if isinstance(sentence, str):\n",
    "        spacy_lang = spacy.load('de')\n",
    "        tokens = [token.text.lower() for token in spacy_lang(sentence)]\n",
    "    else:\n",
    "        tokens = [token.lower() for token in sentence]\n",
    "\n",
    "    # 补充起始、终止占位符，统一序列长度\n",
    "    if len(tokens) > max_len - 2:\n",
    "        src_len = max_len\n",
    "        tokens = ['<bos>'] + tokens[:max_len - 2] + ['<eos>']\n",
    "    else:\n",
    "        src_len = len(tokens) + 2\n",
    "        tokens = ['<bos>'] + tokens + ['<eos>'] + ['<pad>'] * (max_len - src_len)\n",
    "\n",
    "    # 将德语单词转化为数字索引\n",
    "    src = de_vocab.encode(tokens)\n",
    "    src = mindspore.Tensor(src, mindspore.int32).expand_dims(1)\n",
    "    src_len = mindspore.Tensor([src_len], mindspore.int32)\n",
    "    trg = mindspore.Tensor([en_vocab.bos_idx], mindspore.int32).expand_dims(1)\n",
    "\n",
    "    # 获得预测结果，并将其转化为英语单词\n",
    "    outputs = model(src, src_len, trg, max_len)\n",
    "    trg_indexes = [int(i.argmax(1).asnumpy()) for i in outputs]\n",
    "    eos_idx = trg_indexes.index(en_vocab.eos_idx) if en_vocab.eos_idx in trg_indexes else -1\n",
    "    trg_tokens = en_vocab.decode(trg_indexes[:eos_idx])\n",
    "\n",
    "    return trg_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用测试数据集中的任意一组文本数据进行预测。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import load_checkpoint, load_param_into_net\n",
    "\n",
    "# 加载之前训练好的模型\n",
    "param_dict = load_checkpoint(ckpt_file_name)\n",
    "load_param_into_net(model, param_dict)\n",
    "\n",
    "# 以测试数据集中的第一组语句为例，进行测试\n",
    "example_idx = 0\n",
    "\n",
    "src = test_dataset[example_idx][0]\n",
    "trg = test_dataset[example_idx][1]\n",
    "\n",
    "print(f'src = {src}')\n",
    "print(f'trg = {trg}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看预测结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "translation = translate_sentence(src, de_vocab, en_vocab, model)\n",
    "\n",
    "print(f'predicted trg = {translation}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BLEU得分\n",
    "\n",
    "双语替换评测得分（bilingual evaluation understudy，BLEU）为衡量文本翻译模型生成出来的语句好坏的一种算法，它的核心在于评估机器翻译的译文 $\\text{pred}$ 与人工翻译的参考译文 $\\text{label}$ 的相似度。通过对机器译文的片段与参考译文进行比较，计算出各个片段的的分数，并配以权重进行加和，基本规则为：\n",
    "\n",
    "1. 惩罚过短的预测，即如果机器翻译出来的译文相对于人工翻译的参考译文过于短小，则命中率越高，需要施加更多的惩罚；\n",
    "2. 对长段落匹配更高的权重，即如果出现长段落的完全命中，说明机器翻译的译文更贴近人工翻译的参考译文；\n",
    "\n",
    "BLEU的公式如下：\n",
    "\n",
    "$$exp(min(0, 1-\\frac{len(\\text{label})}{len(\\text{pred})})\\Pi^k_{n=1}p_n^{1/2^n})$$\n",
    "\n",
    "- `len(label)`：人工翻译的译文长度\n",
    "- `len(pred)`：机器翻译的译文长度\n",
    "- `p_n`：n-gram的精度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.translate.bleu_score import corpus_bleu\n",
    "\n",
    "def calculate_bleu(dataset, de_vocab, en_vocab, model, max_len=50):\n",
    "    trgs = []\n",
    "    pred_trgs = []\n",
    "\n",
    "    for data in dataset:\n",
    "\n",
    "        src = data[0] # 源语句：德语\n",
    "        trg = data[1] # 目标语句：英语\n",
    "\n",
    "        # 获取模型预测结果\n",
    "        pred_trg = translate_sentence(src, de_vocab, en_vocab, model, max_len)\n",
    "        pred_trgs.append(pred_trg)\n",
    "        trgs.append([trg])\n",
    "\n",
    "    return corpus_bleu(trgs, pred_trgs)\n",
    "\n",
    "# 计算BLEU Score\n",
    "bleu_score = calculate_bleu(test_dataset, de_vocab, en_vocab, model)\n",
    "\n",
    "print(f'BLEU score = {bleu_score*100:.2f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考文献\n",
    "\n",
    "Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. 2014. Neural machine translation by jointly learning\n",
    "to align and translate. arXiv preprint arXiv:1409.0473."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "cadaf88f1476df1afa3bcea73fea58f4caea9e659abbd9272448eb6df92fe30a"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
